{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Habitat Lab: Topdown Map Visualization\n",
    "\n",
    "## Initial setup and imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# [setup]\n",
    "import os\n",
    "from typing import TYPE_CHECKING, Union, cast\n",
    "\n",
    "import git\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Quiet the Habitat simulator logging.\n",
    "# These must be set *before* habitat_sim is imported below (directly via\n",
    "# `habitat_sim.utils`, and possibly earlier via `habitat`): habitat_sim reads\n",
    "# HABITAT_SIM_LOG when its bindings are loaded, so assigning it after the\n",
    "# imports would leave the simulator logger at its default verbosity.\n",
    "os.environ[\"MAGNUM_LOG\"] = \"quiet\"\n",
    "os.environ[\"HABITAT_SIM_LOG\"] = \"quiet\"\n",
    "\n",
    "# noqa: E402 -- these imports intentionally follow the environment setup.\n",
    "import habitat  # noqa: E402\n",
    "from habitat.config.default_structured_configs import (  # noqa: E402\n",
    "    CollisionsMeasurementConfig,\n",
    "    FogOfWarConfig,\n",
    "    TopDownMapMeasurementConfig,\n",
    ")\n",
    "from habitat.core.agent import Agent  # noqa: E402\n",
    "from habitat.tasks.nav.nav import (  # noqa: E402\n",
    "    NavigationEpisode,\n",
    "    NavigationGoal,\n",
    ")\n",
    "from habitat.tasks.nav.shortest_path_follower import (  # noqa: E402\n",
    "    ShortestPathFollower,\n",
    ")\n",
    "from habitat.utils.visualizations import maps  # noqa: E402\n",
    "from habitat.utils.visualizations.utils import (  # noqa: E402\n",
    "    images_to_video,\n",
    "    observations_to_image,\n",
    "    overlay_frame,\n",
    ")\n",
    "from habitat_sim.utils import viz_utils as vut  # noqa: E402\n",
    "\n",
    "if TYPE_CHECKING:\n",
    "    from habitat.core.simulator import Observations\n",
    "    from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim\n",
    "\n",
    "# Resolve all paths from the root of the enclosing git checkout, so the\n",
    "# notebook does not depend on the directory Jupyter was launched from.\n",
    "repo = git.Repo(\".\", search_parent_directories=True)\n",
    "dir_path = repo.working_tree_dir\n",
    "data_path = os.path.join(dir_path, \"data\")\n",
    "output_path = os.path.join(\n",
    "    dir_path, \"examples/tutorials/habitat_lab_visualization/\"\n",
    ")\n",
    "os.makedirs(output_path, exist_ok=True)\n",
    "# NOTE(review): presumably needed so relative data paths inside the habitat\n",
    "# configs resolve against the repo root -- confirm.\n",
    "os.chdir(dir_path)\n",
    "# [/setup]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the test 3D scenes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quote the path so a checkout under a directory containing spaces still works.\n",
    "!python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path \"{data_path}\" --no-replace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download point-goal navigation episodes for the test scenes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Quote the path so a checkout under a directory containing spaces still works.\n",
    "!python -m habitat_sim.utils.datasets_download --uids habitat_test_pointnav_dataset --data-path \"{data_path}\" --no-replace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# [example_1]\n",
    "def example_pointnav_draw_target_birdseye_view():\n",
    "    \"\"\"Draw the pointnav target birdseye view for one dummy episode.\n",
    "\n",
    "    The agent starts at (0, 0.25, 0) with rotation -pi/4 and the goal sits at\n",
    "    (10, 0.25, 10) with radius 0.5. The image is only displayed with\n",
    "    matplotlib; the title is a suggested filename, nothing is saved.\n",
    "    \"\"\"\n",
    "    goal_radius = 0.5\n",
    "    start_position = [0, 0.25, 0]\n",
    "    start_rotation = -np.pi / 4\n",
    "\n",
    "    # Throwaway episode; only its goal is read back below.\n",
    "    dummy_episode = NavigationEpisode(\n",
    "        goals=[NavigationGoal(position=[10, 0.25, 10], radius=goal_radius)],\n",
    "        episode_id=\"dummy_id\",\n",
    "        scene_id=\"dummy_scene\",\n",
    "        start_position=start_position,\n",
    "        start_rotation=start_rotation,  # type: ignore[arg-type]\n",
    "    )\n",
    "    first_goal = dummy_episode.goals[0]\n",
    "\n",
    "    target_image = maps.pointnav_draw_target_birdseye_view(\n",
    "        np.array(start_position),\n",
    "        start_rotation,\n",
    "        np.asarray(first_goal.position),\n",
    "        goal_radius=first_goal.radius,\n",
    "        agent_radius_px=25,\n",
    "    )\n",
    "    plt.imshow(target_image)\n",
    "    plt.title(\"pointnav_target_image.png\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# [/example_1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# [example_2]\n",
    "def example_pointnav_draw_target_birdseye_view_agent_on_border():\n",
    "    \"\"\"Draw the pointnav target birdseye view with the agent on each border.\n",
    "\n",
    "    One goal is placed at (0, 0.25, 0) and four dummy episodes are built with\n",
    "    the agent offset by 7.8 along exactly one horizontal axis: -x, -z, +z, +x\n",
    "    (in that order). Each view is displayed with matplotlib; the titles are\n",
    "    suggested filenames, nothing is written to disk.\n",
    "    \"\"\"\n",
    "    goal_radius = 0.5\n",
    "    goal = NavigationGoal(position=[0, 0.25, 0], radius=goal_radius)\n",
    "    agent_rotation = np.pi / 2\n",
    "\n",
    "    # Keep only offsets where exactly one axis is non-zero; the nested\n",
    "    # comprehension preserves the x-major visiting order of the titles.\n",
    "    border_offsets = [\n",
    "        (x_edge, y_edge)\n",
    "        for x_edge in (-1, 0, 1)\n",
    "        for y_edge in (-1, 0, 1)\n",
    "        if (x_edge == 0) != (y_edge == 0)\n",
    "    ]\n",
    "    for edge_index, (x_edge, y_edge) in enumerate(border_offsets, start=1):\n",
    "        agent_position = [7.8 * x_edge, 0.25, 7.8 * y_edge]\n",
    "        dummy_episode = NavigationEpisode(\n",
    "            goals=[goal],\n",
    "            episode_id=\"dummy_id\",\n",
    "            scene_id=\"dummy_scene\",\n",
    "            start_position=agent_position,\n",
    "            start_rotation=agent_rotation,  # type: ignore[arg-type]\n",
    "        )\n",
    "        episode_goal = dummy_episode.goals[0]\n",
    "\n",
    "        target_image = maps.pointnav_draw_target_birdseye_view(\n",
    "            np.array(agent_position),\n",
    "            agent_rotation,\n",
    "            np.asarray(episode_goal.position),\n",
    "            goal_radius=episode_goal.radius,\n",
    "            agent_radius_px=25,\n",
    "        )\n",
    "        plt.imshow(target_image)\n",
    "        plt.title(\"pointnav_target_image_edge_%d.png\" % edge_index)\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "# [/example_2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# [example_3]\n",
    "def example_get_topdown_map():\n",
    "    \"\"\"Display the top-down map of the first test episode's scene.\n",
    "\n",
    "    Loads the pointnav test config and dataset, resets the env to its first\n",
    "    episode, asks the simulator for a 1024-resolution top-down map and shows\n",
    "    it through a three-entry grayscale palette.\n",
    "    \"\"\"\n",
    "    config_file = os.path.join(\n",
    "        dir_path,\n",
    "        \"habitat-lab/habitat/config/benchmark/nav/pointnav/pointnav_habitat_test.yaml\",\n",
    "    )\n",
    "    config = habitat.get_config(config_path=config_file)\n",
    "    dataset_config = config.habitat.dataset\n",
    "    dataset = habitat.make_dataset(\n",
    "        id_dataset=dataset_config.type, config=dataset_config\n",
    "    )\n",
    "\n",
    "    # `get_topdown_map_from_sim` labels cells 0 (occupied), 1 (unoccupied)\n",
    "    # and 2 (border); indexing this palette with those labels paints them\n",
    "    # white, gray and black respectively.\n",
    "    palette = np.array(\n",
    "        [[255, 255, 255], [128, 128, 128], [0, 0, 0]], dtype=np.uint8\n",
    "    )\n",
    "\n",
    "    with habitat.Env(config=config, dataset=dataset) as env:\n",
    "        env.reset()  # load the first episode\n",
    "        cell_labels = maps.get_topdown_map_from_sim(\n",
    "            cast(\"HabitatSim\", env.sim), map_resolution=1024\n",
    "        )\n",
    "        plt.imshow(palette[cell_labels])\n",
    "        plt.title(\"top_down_map.png\")\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "# [/example_3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# [example_4]\n",
    "class ShortestPathFollowerAgent(Agent):\n",
    "    r\"\"\"Implementation of the :ref:`habitat.core.agent.Agent` interface that\n",
    "    uses :ref:`habitat.tasks.nav.shortest_path_follower.ShortestPathFollower`\n",
    "    utility class for extracting the action on the shortest path to the goal.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env: habitat.Env, goal_radius: float):\n",
    "        \"\"\"\n",
    "        :param env: environment the agent acts in; its simulator backs the\n",
    "            shortest-path follower and its current episode supplies the goal.\n",
    "        :param goal_radius: forwarded to ``ShortestPathFollower`` as its\n",
    "            ``goal_radius``.\n",
    "        \"\"\"\n",
    "        self.env = env\n",
    "        self.shortest_path_follower = ShortestPathFollower(\n",
    "            sim=cast(\"HabitatSim\", env.sim),\n",
    "            goal_radius=goal_radius,\n",
    "            # presumably: action ids instead of one-hot vectors\n",
    "            return_one_hot=False,\n",
    "        )\n",
    "\n",
    "    def act(\n",
    "        self, observations: \"Observations\"\n",
    "    ) -> Union[int, np.ndarray, None]:\n",
    "        \"\"\"Return the next action towards the current episode's first goal.\n",
    "\n",
    "        ``observations`` is not used; the action comes from the follower's\n",
    "        query against the simulator. The result may be ``None``, which the\n",
    "        caller in this notebook treats as \"stop stepping\".\n",
    "        \"\"\"\n",
    "        return self.shortest_path_follower.get_next_action(\n",
    "            cast(NavigationEpisode, self.env.current_episode).goals[0].position\n",
    "        )\n",
    "\n",
    "    def reset(self) -> None:\n",
    "        \"\"\"No internal state; the goal is re-read from the env on each act.\"\"\""
    "\n",
    "\n",
    "def example_top_down_map_measure():\n",
    "    \"\"\"Drive a shortest-path agent through a test episode and save a video.\n",
    "\n",
    "    Adds the ``top_down_map`` and ``collisions`` measures to the pointnav\n",
    "    test config and lets :class:`ShortestPathFollowerAgent` act until the\n",
    "    episode is over (or the follower returns ``None``). Each step is rendered\n",
    "    as observations + top-down map with the other metrics overlaid. The\n",
    "    frames are written as an mp4 under ``output_path`` and displayed.\n",
    "    \"\"\"\n",
    "    # Create habitat config\n",
    "    config = habitat.get_config(\n",
    "        config_path=os.path.join(\n",
    "            dir_path,\n",
    "            \"habitat-lab/habitat/config/benchmark/nav/pointnav/pointnav_habitat_test.yaml\",\n",
    "        )\n",
    "    )\n",
    "    # Add habitat.tasks.nav.nav.TopDownMap and habitat.tasks.nav.nav.Collisions measures\n",
    "    with habitat.config.read_write(config):\n",
    "        config.habitat.task.measurements.update(\n",
    "            {\n",
    "                \"top_down_map\": TopDownMapMeasurementConfig(\n",
    "                    map_padding=3,\n",
    "                    map_resolution=1024,\n",
    "                    draw_source=True,\n",
    "                    draw_border=True,\n",
    "                    draw_shortest_path=True,\n",
    "                    draw_view_points=True,\n",
    "                    draw_goal_positions=True,\n",
    "                    draw_goal_aabbs=True,\n",
    "                    fog_of_war=FogOfWarConfig(\n",
    "                        draw=True,\n",
    "                        visibility_dist=5.0,\n",
    "                        fov=90,\n",
    "                    ),\n",
    "                ),\n",
    "                \"collisions\": CollisionsMeasurementConfig(),\n",
    "            }\n",
    "        )\n",
    "    # Create dataset\n",
    "    dataset = habitat.make_dataset(\n",
    "        id_dataset=config.habitat.dataset.type, config=config.habitat.dataset\n",
    "    )\n",
    "\n",
    "    def _render_frame(sim_env: habitat.Env, observations: \"Observations\"):\n",
    "        \"\"\"Turn the current observations and metrics into one video frame.\"\"\"\n",
    "        info = sim_env.get_metrics()\n",
    "        # Concatenate RGB-D observation and top-down map into one image\n",
    "        frame = observations_to_image(observations, info)\n",
    "        # The map is already part of the image; drop it so only the\n",
    "        # remaining metrics are overlaid onto the frame as text.\n",
    "        info.pop(\"top_down_map\")\n",
    "        return overlay_frame(frame, info)\n",
    "\n",
    "    # Create simulation environment\n",
    "    with habitat.Env(config=config, dataset=dataset) as env:\n",
    "        # Create ShortestPathFollowerAgent agent\n",
    "        agent = ShortestPathFollowerAgent(\n",
    "            env=env,\n",
    "            goal_radius=config.habitat.task.measurements.success.success_distance,\n",
    "        )\n",
    "        # Create video of agent navigating in the first episode\n",
    "        num_episodes = 1\n",
    "        for _ in range(num_episodes):\n",
    "            # Load the episode and reset the agent\n",
    "            observations = env.reset()\n",
    "            agent.reset()\n",
    "            vis_frames = [_render_frame(env, observations)]\n",
    "\n",
    "            # Step until the episode is over or the follower has no action\n",
    "            while not env.episode_over:\n",
    "                action = agent.act(observations)\n",
    "                if action is None:\n",
    "                    break\n",
    "                observations = env.step(action)\n",
    "                vis_frames.append(_render_frame(env, observations))\n",
    "\n",
    "            current_episode = env.current_episode\n",
    "            video_name = f\"{os.path.basename(current_episode.scene_id)}_{current_episode.episode_id}\"\n",
    "            # Create video from images and save to disk\n",
    "            images_to_video(\n",
    "                vis_frames, output_path, video_name, fps=6, quality=9\n",
    "            )\n",
    "            # Release the frames before displaying the saved file.\n",
    "            vis_frames.clear()\n",
    "            # output_path may already end with '/', so join instead of\n",
    "            # concatenating to avoid a doubled separator.\n",
    "            vut.display_video(os.path.join(output_path, f\"{video_name}.mp4\"))\n",
    "\n",
    "\n",
    "# [/example_4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # Run the examples in the order they are defined in this notebook.\n",
    "    for run_example in (\n",
    "        example_pointnav_draw_target_birdseye_view,\n",
    "        example_pointnav_draw_target_birdseye_view_agent_on_border,\n",
    "        example_get_topdown_map,\n",
    "        example_top_down_map_measure,\n",
    "    ):\n",
    "        run_example()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "nb_python//py:percent,notebooks//ipynb",
   "notebook_metadata_filter": "all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
